import csv
import sys

from torch._appdirs import unicode


class DataProcessor(object):
    """Base class for data converters for sequence classification data sets."""

    def get_train_examples(self, data_dir):
        """Gets a collection of `InputExample`s for the train set."""
        raise NotImplementedError()

    def get_dev_examples(self, data_dir):
        """Gets a collection of `InputExample`s for the dev set."""
        raise NotImplementedError()

    def get_labels(self):
        """Hook for subclasses to provide the labels of the data set.

        Unlike the other hooks, the base implementation does not raise: it
        does nothing and returns ``None``.
        """
        return None

    @classmethod
    def _read_tsv(cls, input_file, quotechar=None):
        """Reads a tab separated value file.

        Args:
            input_file: Path of the TSV file; it is decoded as UTF-8.
            quotechar: Quote character forwarded to ``csv.reader``. With the
                default ``None`` the csv module disables quoting, so a ``"``
                inside a cell is kept as a literal character.

        Returns:
            A list with one entry per row, each entry being the list of
            string cells of that row.
        """
        # newline="" leaves line-ending handling to the csv module, as its
        # documentation recommends for file objects passed to csv.reader.
        with open(input_file, "r", encoding="utf-8", newline="") as f:
            # quotechar must be forwarded; otherwise the csv default ('"')
            # applies and a cell starting with an unbalanced quote swallows
            # the following lines.
            reader = csv.reader(f, delimiter="\t", quotechar=quotechar)
            return list(reader)

def tsv2obj(filepath, sep='\t'):
    """Read a .tsv file and return its lines as a list of strings.

    Usage:
        filepath = 'data/train.tsv'
        rdr = tsv2obj(filepath)
        print(type(rdr), type(rdr[1]), rdr[1])  # -> list, str, one line

    Args:
        filepath: Path of the file to read; it is decoded as UTF-8.
        sep: Unused. Kept only so existing callers that pass it keep working.

    Returns:
        A list with one string per row of the file. Tab characters are left
        inside each string; a blank line yields an empty string.

    Note:
        The file is parsed by ``csv.reader`` with its default (comma, ``"``)
        dialect, so quote characters that the csv module interprets are not
        restored in the returned strings.
    """
    # newline="" leaves line-ending handling to the csv module.
    with open(filepath, "r", encoding="utf-8", newline="") as source:
        # csv.reader splits each row on commas. Re-joining with "," puts the
        # row back together; joining with "" (as before) silently deleted
        # every comma from the data. A single-cell row joins to itself and an
        # empty row joins to "".
        return [",".join(row) for row in csv.reader(source)]


def tsv_getby_col(filepath, outpath, outcol):
    """Copy the selected tab-separated columns of ``filepath`` to ``outpath``.

    Usage:
        outcol = [0, 2, 3, 4, 5]
        filepath = 'data/train.tsv'
        outpath = 'data/train2.tsv'
        outlist = tsv_getby_col(filepath, outpath, outcol)
        print(outlist[0])

    Args:
        filepath: Path of the input file; it is decoded as UTF-8.
        outpath: Path of the output file; it is created or overwritten and
            written as UTF-8.
        outcol: Iterable of column indices to keep, in output order.

    Returns:
        A list with one string per input row (all columns, tabs included).

    Raises:
        IndexError: If a row has fewer tab-separated columns than an index in
            ``outcol`` requires (this includes blank lines).

    Note:
        Rows are written with ``csv.writer`` in its default dialect, so the
        output is comma-delimited even when ``outpath`` ends in ``.tsv``.
        NOTE(review): this looks unintended, but it is kept because existing
        consumers may rely on the current format -- confirm before changing.
    """
    outlist = []
    # newline="" on both files leaves line-ending handling to the csv module;
    # without it csv.writer output gets "\r\r\n" line endings on Windows.
    with open(filepath, "r", encoding="utf-8", newline="") as source, \
            open(outpath, "w", encoding="utf-8", newline="") as result:
        wtr = csv.writer(result)
        for row in csv.reader(source):
            # The reader splits on commas, so re-join with "," to rebuild the
            # row. Joining with "" (as before) deleted every comma from both
            # the returned line and the selected columns.
            line = ",".join(row)
            outlist.append(line)
            fields = line.split('\t')
            wtr.writerow([fields[i] for i in outcol])
    return outlist
